{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "a885cf5d-c525-4f5b-a8e4-f67d2f699909",
   "metadata": {},
   "source": [
    "## Copyright 2023 Google LLC"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d891d022-8979-40d4-848f-ecb84c17f12c",
   "metadata": {
    "jp-MarkdownHeadingCollapsed": true
   },
   "outputs": [],
   "source": [
    "# Copyright 2023 Google LLC\n",
    "#\n",
    "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
    "# you may not use this file except in compliance with the License.\n",
    "# You may obtain a copy of the License at\n",
    "#\n",
    "#      http://www.apache.org/licenses/LICENSE-2.0\n",
    "#\n",
    "# Unless required by applicable law or agreed to in writing, software\n",
    "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
    "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
    "# See the License for the specific language governing permissions and\n",
    "# limitations under the License."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "540d8642-c203-471c-a66d-0d43aabb0706",
   "metadata": {},
   "source": [
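    "# StyleAligned over SDXL from an input image\n",
    "\n",
    "Invert a reference image with DDIM, then use StyleAligned shared attention on SDXL to generate new images whose style follows that reference."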
   ]
  },
  {
   "cell_type": "markdown",
   "id": "483d0cf9",
   "metadata": {},
   "source": [
    "#### Model loading"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23d54ea7-f7ab-4548-9b10-ece87216dc18",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import StableDiffusionXLPipeline, DDIMScheduler\n",
    "import torch\n",
    "import mediapy\n",
    "import sa_handler\n",
    "import math\n",
    "\n",
    "\n",
    "# DDIM scheduler with a scaled-linear beta schedule (0.00085 -> 0.012),\n",
    "# no sample clipping, and the final alpha not forced to one.\n",
    "scheduler = DDIMScheduler(\n",
    "    beta_start=0.00085,\n",
    "    beta_end=0.012,\n",
    "    beta_schedule=\"scaled_linear\",\n",
    "    clip_sample=False,\n",
    "    set_alpha_to_one=False,\n",
    ")\n",
    "\n",
    "# SDXL base in half precision (fp16 safetensors variant), driven by the DDIM\n",
    "# scheduler above. A CUDA device is required.\n",
    "pipeline = StableDiffusionXLPipeline.from_pretrained(\n",
    "    \"stabilityai/stable-diffusion-xl-base-1.0\",\n",
    "    torch_dtype=torch.float16,\n",
    "    variant=\"fp16\",\n",
    "    use_safetensors=True,\n",
    "    scheduler=scheduler,\n",
    ")\n",
    "pipeline = pipeline.to(\"cuda\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c09b1a68",
   "metadata": {
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "#### Reference image loading and DDIM inversion"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f4717854",
   "metadata": {
    "pycharm": {
     "name": "#%%\n"
    }
   },
   "outputs": [],
   "source": [
    "# DDIM inversion of the reference image.\n",
    "# The image is inverted under its own prompt (which already names the style).\n",
    "# The generation cell below uses the resulting latent trajectory `zts` to pin\n",
    "# the first image of its batch to this reference.\n",
    "\n",
    "from diffusers.utils import load_image\n",
    "import inversion\n",
    "import numpy as np\n",
    "\n",
    "src_style = \"medieval painting\"\n",
    "src_prompt = f'Man laying in a bed, {src_style}.'\n",
    "image_path = './example_image/medieval-bed.jpeg'\n",
    "\n",
    "# Also used by the generation cell. Presumably both must agree so the inverted\n",
    "# trajectory lines up with the sampling steps.\n",
    "num_inference_steps = 50\n",
    "# Passed as ddim_inversion's 5th positional argument. Presumably the guidance\n",
    "# scale used during inversion (TODO confirm in inversion.py).\n",
    "inversion_guidance_scale = 2\n",
    "\n",
    "x0 = np.array(load_image(image_path).resize((1024, 1024)))\n",
    "assert x0.shape == (1024, 1024, 3), f'expected a 1024x1024 RGB image, got {x0.shape}'\n",
    "zts = inversion.ddim_inversion(pipeline, x0, src_prompt, num_inference_steps, inversion_guidance_scale)\n",
    "mediapy.show_image(x0, title=\"input reference image\", height=256)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1751c4fe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Style-aligned generation. The first latent is replaced by the inverted reference\n",
    "# latent. The StyleAligned handler shares normalization and attention with it,\n",
    "# so the remaining prompts follow the reference style.\n",
    "\n",
    "prompts = [\n",
    "    src_prompt,\n",
    "    \"A man working on a laptop\",\n",
    "    \"A man eats pizza\",\n",
    "    \"A woman playing saxophone\",\n",
    "]\n",
    "\n",
    "# Some parameters you can adjust to control fidelity to the reference:\n",
    "shared_score_shift = np.log(2)  # higher value induces higher fidelity, set 0 for no shift\n",
    "shared_score_scale = 1.0  # higher value induces higher fidelity, set 1 for no rescale\n",
    "\n",
    "# For very famous images, consider suppressing attention to the reference. Example configuration:\n",
    "# shared_score_shift = np.log(1)\n",
    "# shared_score_scale = 0.5\n",
    "\n",
    "# Append the reference style to every target prompt (src_prompt already contains it).\n",
    "style_suffix = f', {src_style}.'\n",
    "for i in range(1, len(prompts)):\n",
    "    prompts[i] = f'{prompts[i]}{style_suffix}'\n",
    "\n",
    "handler = sa_handler.Handler(pipeline)\n",
    "sa_args = sa_handler.StyleAlignedArgs(\n",
    "    share_group_norm=True, share_layer_norm=True, share_attention=True,\n",
    "    adain_queries=True, adain_keys=True, adain_values=False,\n",
    "    shared_score_shift=shared_score_shift, shared_score_scale=shared_score_scale,)\n",
    "handler.register(sa_args)\n",
    "try:\n",
    "    # offset chooses where along the inverted trajectory the reference latent starts\n",
    "    # (NOTE(review): confirm exact semantics in inversion.make_inversion_callback).\n",
    "    zT, inversion_callback = inversion.make_inversion_callback(zts, offset=5)\n",
    "\n",
    "    # Seeded CPU generator so the initial random latents are reproducible.\n",
    "    g_cpu = torch.Generator(device='cpu')\n",
    "    g_cpu.manual_seed(10)\n",
    "\n",
    "    latents = torch.randn(len(prompts), 4, 128, 128, device='cpu', generator=g_cpu,\n",
    "                          dtype=pipeline.unet.dtype,).to(pipeline.device)\n",
    "    latents[0] = zT  # the first image starts from the inverted reference\n",
    "\n",
    "    images_a = pipeline(prompts, latents=latents,\n",
    "                        callback_on_step_end=inversion_callback,\n",
    "                        num_inference_steps=num_inference_steps, guidance_scale=10.0).images\n",
    "finally:\n",
    "    # Unregister the StyleAligned hooks even if generation fails (e.g. CUDA OOM),\n",
    "    # so the shared pipeline is not left patched for later cells.\n",
    "    handler.remove()\n",
    "\n",
    "# Titles are the prompts without the appended style suffix.\n",
    "mediapy.show_images(images_a, titles=[p.removesuffix(style_suffix) for p in prompts])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
